#!/usr/bin/env python3

#####################################################################
#
# watermarkstat is a tool for displaying and clearing watermarks.
#
#####################################################################

import argparse
import json
import os
import sys

from natsort import natsorted
from tabulate import tabulate

# Mock the Redis DB for unit-test purposes.
# The whole setup is driven by environment variables; reading an unset
# variable raises KeyError, which is caught below and ends the setup early.
try:
    if os.environ["UTILITIES_UNIT_TESTING"] == "2":
        # Put the parent directory of this script and its "tests" subdirectory
        # at the front of sys.path so that `mock_tables` can be imported.
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        tests_path = os.path.join(modules_path, "tests")
        sys.path.insert(0, modules_path)
        sys.path.insert(0, tests_path)
        # NOTE(review): presumably importing this module swaps in the mocked
        # DB connector -- confirm in tests/mock_tables.
        from mock_tables import dbconnector

        if os.environ["WATERMARKSTAT_UNIT_TESTING"] == "1":
            input_path = os.path.join(tests_path, "wm_input")
            mock_db_path = os.path.join(input_path, "mock_db")

            # Register every "*.json" file found in wm_input/mock_db as a
            # dedicated DB, keyed by the upper-cased file name without its
            # extension.
            # NOTE(review): the registered path also omits the ".json"
            # extension -- presumably dbconnector appends it; confirm.
            for mock_db in os.listdir(mock_db_path):
                name, ext = os.path.splitext(mock_db)
                if ext == ".json":
                    dbconnector.dedicated_dbs[name.upper()] = os.path.join(mock_db_path, name)
except KeyError:
    # One of the environment variables above is not set: skip whatever part
    # of the mock setup has not run yet.
    pass

from swsscommon.swsscommon import SonicV2Connector


# Column headers of the table printed for the buffer-pool and headroom-pool
# watermark types.
headerBufferPool = ['Pool', 'Bytes']


# Placeholder cell values. STATUS_NA is printed when a watermark value cannot
# be read from the DB; STATUS_INVALID is not referenced in the visible part of
# this file.
STATUS_NA = 'N/A'
STATUS_INVALID = 'INVALID'

# Short queue-type tags used inside this script to sort queues into the
# unicast / multicast / "all" per-port maps.
QUEUE_TYPE_MC = 'MC'
QUEUE_TYPE_UC = 'UC'
QUEUE_TYPE_ALL = 'ALL'
# Queue-type values as read from COUNTERS_QUEUE_TYPE_MAP; each one is
# translated to the matching short tag above.
SAI_QUEUE_TYPE_MULTICAST = "SAI_QUEUE_TYPE_MULTICAST"
SAI_QUEUE_TYPE_UNICAST = "SAI_QUEUE_TYPE_UNICAST"
SAI_QUEUE_TYPE_ALL = "SAI_QUEUE_TYPE_ALL"

# Key prefixes that are prepended to an object id to form the DB key a
# watermark is read from. main() selects the PERSISTENT or the USER prefix;
# the COUNTER and PERIODIC prefixes are not referenced in the visible part of
# this file.
COUNTER_TABLE_PREFIX = "COUNTERS:"
PERSISTENT_TABLE_PREFIX = "PERSISTENT_WATERMARKS:"
PERIODIC_TABLE_PREFIX = "PERIODIC_WATERMARKS:"
USER_TABLE_PREFIX = "USER_WATERMARKS:"

# Names of the COUNTERS_DB maps queried by this script. The *_NAME_MAP tables
# are read in full (object name -> object id); the *_TYPE_MAP, *_INDEX_MAP and
# *_PORT_MAP tables are queried per object id.
COUNTERS_PORT_NAME_MAP = "COUNTERS_PORT_NAME_MAP"
COUNTERS_QUEUE_NAME_MAP = "COUNTERS_QUEUE_NAME_MAP"
COUNTERS_QUEUE_TYPE_MAP = "COUNTERS_QUEUE_TYPE_MAP"
COUNTERS_QUEUE_INDEX_MAP = "COUNTERS_QUEUE_INDEX_MAP"
COUNTERS_QUEUE_PORT_MAP = "COUNTERS_QUEUE_PORT_MAP"
COUNTERS_PG_NAME_MAP = "COUNTERS_PG_NAME_MAP"
COUNTERS_PG_PORT_MAP = "COUNTERS_PG_PORT_MAP"
COUNTERS_PG_INDEX_MAP = "COUNTERS_PG_INDEX_MAP"
COUNTERS_BUFFER_POOL_NAME_MAP = "COUNTERS_BUFFER_POOL_NAME_MAP"


class Watermarkstat(object):
    """Read buffer watermarks from COUNTERS_DB and print them as tables.

    On construction the object connects to COUNTERS_DB and APPL_DB and builds,
    for every port listed in COUNTERS_PORT_NAME_MAP, the maps of its unicast
    queues, multicast queues, "all"-type queues and priority groups.  A DB
    lookup that returns None is reported on stderr and terminates the process
    with exit status 1.

    Attributes set by __init__:
        counters_db / app_db: the two DB connectors.
        counter_port_name_map: port name -> port object id.
        port_name_map: port object id -> port name.
        port_uc_queues_map / port_mc_queues_map / port_all_queues_map /
        port_pg_map: port name -> {object name -> object id}.
        buffer_pool_name_to_oid_map: buffer pool name -> object id.
        watermark_types: per watermark type, the table title, the counter
            name to read and either the object map / index function / column
            prefix (per-port types) or the fixed header (pool types).
    """

    def __init__(self):
        self.counters_db = SonicV2Connector(use_unix_socket_path=False)
        self.counters_db.connect(self.counters_db.COUNTERS_DB)

        # connect APP DB for clear notifications
        self.app_db = SonicV2Connector(use_unix_socket_path=False)
        self.app_db.connect(self.app_db.APPL_DB)

        # Get all ports
        self.counter_port_name_map = self._get_name_map(COUNTERS_PORT_NAME_MAP)

        self.port_uc_queues_map = {}
        self.port_mc_queues_map = {}
        self.port_all_queues_map = {}
        self.port_pg_map = {}
        self.port_name_map = {}

        for port, port_oid in self.counter_port_name_map.items():
            self.port_uc_queues_map[port] = {}
            self.port_mc_queues_map[port] = {}
            self.port_all_queues_map[port] = {}
            self.port_pg_map[port] = {}
            self.port_name_map[port_oid] = port

        # Get Queues for each port
        queue_maps_by_type = {
            QUEUE_TYPE_UC: self.port_uc_queues_map,
            QUEUE_TYPE_MC: self.port_mc_queues_map,
            QUEUE_TYPE_ALL: self.port_all_queues_map,
        }
        counter_queue_name_map = self._get_name_map(COUNTERS_QUEUE_NAME_MAP)
        for queue, queue_oid in counter_queue_name_map.items():
            port = self.port_name_map[self._get_map_entry(COUNTERS_QUEUE_PORT_MAP, queue_oid, "Port")]
            # Resolve the type once per queue: one DB read instead of one per
            # comparison.
            queue_type = self._get_queue_type(queue_oid)
            queue_maps_by_type[queue_type][port][queue] = queue_oid

        # Get PGs for each port
        counter_pg_name_map = self._get_name_map(COUNTERS_PG_NAME_MAP)
        for pg, pg_oid in counter_pg_name_map.items():
            port = self.port_name_map[self._get_map_entry(COUNTERS_PG_PORT_MAP, pg_oid, "Port")]
            self.port_pg_map[port][pg] = pg_oid

        # Get all buffer pools
        self.buffer_pool_name_to_oid_map = self._get_name_map(COUNTERS_BUFFER_POOL_NAME_MAP)

        self.watermark_types = {
            "pg_headroom"   : {"message" : "Ingress headroom per PG:",
                               "obj_map" : self.port_pg_map,
                               "idx_func": self.get_pg_index,
                               "wm_name" : "SAI_INGRESS_PRIORITY_GROUP_STAT_XOFF_ROOM_WATERMARK_BYTES",
                               "header_prefix": "PG"},
            "pg_shared"     : {"message" : "Ingress shared pool occupancy per PG:",
                               "obj_map" : self.port_pg_map,
                               "idx_func": self.get_pg_index,
                               "wm_name" : "SAI_INGRESS_PRIORITY_GROUP_STAT_SHARED_WATERMARK_BYTES",
                               "header_prefix": "PG"},
            "q_shared_uni"  : {"message" : "Egress shared pool occupancy per unicast queue:",
                               "obj_map" : self.port_uc_queues_map,
                               "idx_func": self.get_queue_index,
                               "wm_name" : "SAI_QUEUE_STAT_SHARED_WATERMARK_BYTES",
                               "header_prefix": "UC"},
            "q_shared_multi": {"message" : "Egress shared pool occupancy per multicast queue:",
                               "obj_map" : self.port_mc_queues_map,
                               "idx_func": self.get_queue_index,
                               "wm_name" : "SAI_QUEUE_STAT_SHARED_WATERMARK_BYTES",
                               "header_prefix": "MC"},
            "q_shared_all":   {"message" : "Egress shared pool occupancy per all queues:",
                               "obj_map" : self.port_all_queues_map,
                               "idx_func": self.get_queue_index,
                               "wm_name" : "SAI_QUEUE_STAT_SHARED_WATERMARK_BYTES",
                               "header_prefix": "ALL"},
            "buffer_pool"   : {"message": "Shared pool maximum occupancy:",
                               "wm_name": "SAI_BUFFER_POOL_STAT_WATERMARK_BYTES",
                               "header" : headerBufferPool},
            "headroom_pool" : {"message": "Headroom pool maximum occupancy:",
                               "wm_name": "SAI_BUFFER_POOL_STAT_XOFF_ROOM_WATERMARK_BYTES",
                               "header" : headerBufferPool}
        }

    def _get_name_map(self, map_name):
        """Return the whole COUNTERS_DB table `map_name`.

        Prints "<map_name> is empty!" on stderr and exits with status 1 when
        the DB returns None for the table.
        """
        name_map = self.counters_db.get_all(self.counters_db.COUNTERS_DB, map_name)
        if name_map is None:
            print("{} is empty!".format(map_name), file=sys.stderr)
            sys.exit(1)

        return name_map

    def _get_map_entry(self, map_name, table_id, description):
        """Return the value stored for `table_id` in COUNTERS_DB table `map_name`.

        `description` names the looked-up property in the error message.
        Prints the error on stderr and exits with status 1 when the DB
        returns None for the entry.
        """
        value = self.counters_db.get(self.counters_db.COUNTERS_DB, map_name, table_id)
        if value is None:
            print("{} is not available in table '{}'".format(description, table_id), file=sys.stderr)
            sys.exit(1)

        return value

    def _get_queue_type(self, table_id):
        """Return the short queue-type tag (UC / MC / ALL) of queue `table_id`.

        Exits with status 1 when the type is missing from
        COUNTERS_QUEUE_TYPE_MAP or is not one of the three known values.
        """
        queue_type = self._get_map_entry(COUNTERS_QUEUE_TYPE_MAP, table_id, "Queue Type")
        if queue_type == SAI_QUEUE_TYPE_MULTICAST:
            return QUEUE_TYPE_MC
        if queue_type == SAI_QUEUE_TYPE_UNICAST:
            return QUEUE_TYPE_UC
        if queue_type == SAI_QUEUE_TYPE_ALL:
            return QUEUE_TYPE_ALL

        print("Queue Type '{}' in table '{}' is invalid".format(queue_type, table_id), file=sys.stderr)
        sys.exit(1)

    def get_queue_index(self, table_id):
        """Return the index stored for queue `table_id` in COUNTERS_QUEUE_INDEX_MAP.

        Exits with status 1 when the index is not available.
        """
        return self._get_map_entry(COUNTERS_QUEUE_INDEX_MAP, table_id, "Queue index")

    def get_pg_index(self, table_id):
        """Return the index stored for priority group `table_id` in COUNTERS_PG_INDEX_MAP.

        Exits with status 1 when the index is not available.
        """
        return self._get_map_entry(COUNTERS_PG_INDEX_MAP, table_id, "Priority group index")

    def build_header(self, wm_type, counter_type):
        """Build the table header for a per-port watermark type.

        Sets `self.header_list` to 'Port' followed by one column per index
        between the smallest and the largest index found in the object names
        of `wm_type["obj_map"]`, and sets `self.min_idx` to the smallest one.

        Args:
            wm_type: entry of `self.watermark_types` (needs "obj_map" and
                "header_prefix").
            counter_type: name of the watermark type; only used to pick the
                message and exit status when the object map has no objects.

        Exits with status 1 when `wm_type` is None or the object map holds no
        objects, except for 'q_shared_multi', where an empty map is reported
        on stdout and the exit status is 0.
        """
        if wm_type is None:
            print("Header info is not available!", file=sys.stderr)
            sys.exit(1)

        self.header_list = ['Port']

        # Object names have the form "<port>:<index>"; collect every index.
        indices = [int(element.split(':')[1])
                   for port_objects in wm_type["obj_map"].values()
                   for element in port_objects]

        if not indices:
            if counter_type != 'q_shared_multi':
                print("Object map is empty!", file=sys.stderr)
                sys.exit(1)
            else:
                print("Object map from the COUNTERS_DB is empty because the multicast queues are not configured in the CONFIG_DB!")
                sys.exit(0)

        self.min_idx = min(indices)
        max_idx = max(indices)
        self.header_list += ["{}{}".format(wm_type["header_prefix"], idx) for idx in range(self.min_idx, max_idx + 1)]

    def get_counters(self, table_prefix, port_obj, idx_func, watermark):
        """Return the watermark values of one port as a list of strings.

        Args:
            table_prefix: prefix prepended to each object id to form the DB key.
            port_obj: {object name -> object id} of the port's queues or PGs.
            idx_func: callable returning the index of an object id.
            watermark: name of the counter field to read.

        Returns:
            One entry per non-port column of `self.header_list`, "0" for
            columns without an object; an empty list when the header has no
            such columns.
        """

        # header list contains the port name followed by the queues/pgs. fields is used to populate the queue/pg values
        fields = ["0"] * (len(self.header_list) - 1)
        if not fields:
            # counters are not enabled.
            return fields

        for obj_id in port_obj.values():
            # Column position of this object relative to the first data column
            pos = int(idx_func(obj_id)) - self.min_idx
            counter_data = self.counters_db.get(self.counters_db.COUNTERS_DB, table_prefix + obj_id, watermark)
            if counter_data is None or counter_data == '':
                fields[pos] = STATUS_NA
            elif fields[pos] != STATUS_NA:
                # A column already marked N/A is not overwritten by another
                # object that maps to the same column.
                fields[pos] = str(int(counter_data))
        return fields

    def print_all_stat(self, table_prefix, key):
        """Print the table of watermark type `key` read from `table_prefix` tables.

        For 'buffer_pool' and 'headroom_pool' one row per buffer pool is
        printed ('headroom_pool' only lists pools whose name contains
        'ingress_lossless'); for every other type one row per port.
        """
        table = []
        wm_type = self.watermark_types[key]
        if key in ['buffer_pool', 'headroom_pool']:
            self.header_list = wm_type['header']
            # Get stats for each buffer pool
            for buf_pool, bp_oid in natsorted(self.buffer_pool_name_to_oid_map.items()):
                if key == 'headroom_pool' and 'ingress_lossless' not in buf_pool:
                    continue

                db_key = table_prefix + bp_oid
                data = self.counters_db.get(self.counters_db.COUNTERS_DB, db_key, wm_type["wm_name"])
                if data is None:
                    data = STATUS_NA
                table.append((buf_pool, data))
        else:
            self.build_header(wm_type, key)
            # Get stat for each port
            for port in natsorted(self.counter_port_name_map):
                data = self.get_counters(table_prefix,
                                         wm_type["obj_map"][port], wm_type["idx_func"], wm_type["wm_name"])
                table.append(tuple([port] + data))

        print(wm_type["message"])
        print(tabulate(table, self.header_list, tablefmt='simple', stralign='right'))

    def send_clear_notification(self, data):
        """Publish `data`, serialized as compact JSON, as a watermark clear request."""
        msg = json.dumps(data, separators=(',', ':'))
        self.app_db.publish('APPL_DB', 'WATERMARK_CLEAR_REQUEST', msg)


def main():
    """Command-line entry point: parse the options, then show or clear watermarks.

    With -c/--clear a clear request for the selected watermark type is
    published; otherwise the selected watermark table is printed. The
    -p/--persistent flag selects the persistent instead of the user
    watermarks in both cases. The process always ends through sys.exit(0)
    once the requested action has been carried out.
    """
    usage_examples = """
Examples:
  watermarkstat -t pg_headroom
  watermarkstat -t pg_shared
  watermarkstat -t q_shared_all
  watermarkstat -p -t q_shared_all
  watermarkstat -t q_shared_all -c
  watermarkstat -t q_shared_uni -c
  watermarkstat -t q_shared_multi -c
  watermarkstat -p -t pg_shared
  watermarkstat -p -t q_shared_multi -c
  watermarkstat -t buffer_pool
  watermarkstat -t buffer_pool -c
  watermarkstat -p -t buffer_pool -c
"""
    supported_types = ['pg_headroom', 'pg_shared', 'q_shared_uni', 'q_shared_multi',
                       'buffer_pool', 'headroom_pool', 'q_shared_all']

    arg_parser = argparse.ArgumentParser(
        description='Display the watermark counters',
        formatter_class=argparse.RawTextHelpFormatter,
        epilog=usage_examples,
    )
    arg_parser.add_argument('-c', '--clear', action='store_true',
                            help='Clear watermarks request')
    arg_parser.add_argument('-p', '--persistent', action='store_true',
                            help='Do the operations on the persistent watermark')
    arg_parser.add_argument('-t', '--type', required=True, action='store',
                            choices=supported_types,
                            help='The type of watermark')
    arg_parser.add_argument('-v', '--version', action='version', version='%(prog)s 1.0')
    options = arg_parser.parse_args()

    # The stat object is created for both actions: it owns the DB connections.
    stat = Watermarkstat()

    if not options.clear:
        if options.persistent:
            prefix = PERSISTENT_TABLE_PREFIX
        else:
            prefix = USER_TABLE_PREFIX
        stat.print_all_stat(prefix, options.type)
    else:
        if options.persistent:
            scope = "PERSISTENT"
        else:
            scope = "USER"
        stat.send_clear_notification((scope, options.type.upper()))

    sys.exit(0)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
